import torch
from torch import nn
import math
import time

# Device string used when this module allocates new tensors: the GPU when
# CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

def cheb_poly(K, l, mu, v, grad_y, params_y):
    """Apply a truncated Chebyshev series in the operator ``-H / (2*l)`` to ``v``.

    ``H @ w`` is never formed explicitly: it is obtained by differentiating
    ``grad_y`` with respect to ``params_y`` with ``grad_outputs=w``. The
    series coefficients ``c * p3**k`` are those of the Chebyshev expansion of
    ``1/x`` on ``[mu / (2*l), 0.5]``, so the result presumably approximates
    ``(-H)^{-1} @ v`` when the eigenvalues of ``-H`` lie in ``[mu, l]``
    (assumption -- confirm against the caller).

    Args:
        K: Series length; terms ``T_0 .. T_{K-1}`` are summed. ``T_0`` and
            ``T_1`` are always included, so every ``K <= 2`` gives the same
            two-term sum.
        l: Scale of the operator (presumably an upper eigenvalue bound).
        mu: Lower end of the interval before scaling. The scalar setup
            divides by ``0.5 - mu / (2*l)`` and by ``sqrt(0.5 * mu / (2*l))``,
            so ``mu`` must be positive and different from ``l``.
        v: Flat vector the series is applied to; passed as ``grad_outputs``.
        grad_y: Differentiable gradient whose derivative w.r.t. ``params_y``
            defines ``H``.
        params_y: Parameters that ``grad_y`` is differentiated against.

    Returns:
        The accumulated series divided by ``2*l``, as a flat vector.
    """
    scale = 2 * l
    lo = mu / scale     # lower end of the scaled interval
    hi = 0.5            # upper end of the scaled interval

    # Affine map of [lo, hi] onto [-1, 1]:  t = slope * x - shift.
    slope = 2 / (hi - lo)
    shift = (hi + lo) / (hi - lo)
    # Geometric ratio and leading factor of the series coefficients.
    ratio = (math.sqrt(lo / hi) - 1) / (math.sqrt(lo / hi) + 1)
    coeff = 2 / math.sqrt(hi * lo)

    def scaled_neg_hvp(vec):
        # (-H @ vec) / (2*l) via a second differentiation of grad_y.
        grads = torch.autograd.grad(grad_y, params_y, grad_outputs=vec, retain_graph=True)
        return -nn.utils.parameters_to_vector(grads) / scale

    prev = v                                        # T_0
    curr = slope * scaled_neg_hvp(v) - shift * v    # T_1

    result = coeff / 2 * prev
    coeff = coeff * ratio
    result = result + coeff * curr

    # Three-term recurrence T_{k+1} = 2 t T_k - T_{k-1} for the remaining terms.
    for _ in range(K - 2):
        coeff = coeff * ratio
        nxt = 2 * slope * scaled_neg_hvp(curr) - 2 * shift * curr - prev
        result = result + coeff * nxt
        prev, curr = curr, nxt

    return result / scale

def cheb_poly3(K, l, mu, v, hess):
    """Apply a truncated Chebyshev series in ``hess / (2*l)`` to ``v``.

    Uses an explicit matrix ``hess`` and plain matrix-vector products. The
    series coefficients ``c * p3**k`` are those of the Chebyshev expansion of
    ``1/x`` on ``[mu / (2*l), 0.5]``, so the result presumably approximates
    ``hess^{-1} @ v`` when the eigenvalues of ``hess`` lie in ``[mu, l]``.
    NOTE(review): no sign flip is applied to ``hess`` here, so the caller is
    presumably expected to pass the positive-definite operator -- confirm.

    Args:
        K: Series length; terms ``T_0 .. T_{K-1}`` are summed. ``T_0`` and
            ``T_1`` are always included, so every ``K <= 2`` gives the same
            two-term sum.
        l: Scale of the operator (presumably an upper eigenvalue bound).
        mu: Lower end of the interval before scaling; must be positive and
            different from ``l`` for the scalar constants to be defined.
        v: Vector the series is applied to.
        hess: Matrix applied as ``hess @ w``.

    Returns:
        The accumulated series divided by ``2*l``.
    """
    two_l = 2 * l
    a = mu / two_l      # lower end of the scaled interval
    b = 0.5             # upper end of the scaled interval

    # t = gain * x - offset maps [a, b] onto [-1, 1].
    gain = 2 / (b - a)
    offset = (b + a) / (b - a)
    decay = (math.sqrt(a / b) - 1) / (math.sqrt(a / b) + 1)
    weight = 2 / math.sqrt(b * a)

    t_prev = v                                          # T_0
    t_curr = gain * ((hess @ v) / two_l) - offset * v   # T_1

    total = weight / 2 * t_prev
    weight = weight * decay
    total = total + weight * t_curr

    remaining = K - 2
    while remaining > 0:
        weight = weight * decay
        scaled = (hess @ t_curr) / two_l
        # Chebyshev recurrence: T_{k+1} = 2 t T_k - T_{k-1}.
        t_next = 2 * gain * scaled - 2 * offset * t_curr - t_prev
        total = total + weight * t_next
        t_prev, t_curr = t_curr, t_next
        remaining -= 1

    return total / two_l

def cheb_poly2(K, l, mu, v, s_fea, t_fea, s_tmp, t_tmp):
    """Apply a truncated Chebyshev series in a factored operator to ``v``.

    The operator is ``A = s_tmp @ s_fea + t_tmp @ t_fea + mu * I``; it is only
    ever applied to vectors, as ``s_tmp @ (s_fea @ w) + t_tmp @ (t_fea @ w)
    + mu * w``. The series is evaluated in ``A / (2*l)`` with coefficients
    ``c * p3**k``, which are those of the Chebyshev expansion of ``1/x`` on
    ``[mu / (2*l), 0.5]``; the result presumably approximates ``A^{-1} @ v``
    when the eigenvalues of ``A`` lie in ``[mu, l]`` (assumption -- confirm).

    Args:
        K: Series length; terms ``T_0 .. T_{K-1}`` are summed. ``T_0`` and
            ``T_1`` are always included, so every ``K <= 2`` gives the same
            two-term sum.
        l: Scale of the operator (presumably an upper eigenvalue bound).
        mu: Diagonal shift of the operator and lower end of the interval
            before scaling; must be positive and different from ``l`` for the
            scalar constants to be defined.
        v: Vector the series is applied to.
        s_fea, s_tmp: Factors of the first product term.
        t_fea, t_tmp: Factors of the second product term.

    Returns:
        The accumulated series divided by ``2*l``.
    """
    denom = 2 * l

    def apply_scaled(vec):
        # A @ vec / (2*l) without materialising A.
        first = s_tmp @ (s_fea @ vec)
        second = t_tmp @ (t_fea @ vec)
        return (first + second + mu * vec) / denom

    left = mu / denom   # lower end of the scaled interval
    right = 0.5         # upper end of the scaled interval

    # t = stretch * x - centre maps [left, right] onto [-1, 1].
    stretch = 2 / (right - left)
    centre = (right + left) / (right - left)
    rho = (math.sqrt(left / right) - 1) / (math.sqrt(left / right) + 1)
    factor = 2 / math.sqrt(right * left)

    older = v                                       # T_0
    newer = stretch * apply_scaled(v) - centre * v  # T_1

    acc = factor / 2 * older
    factor = factor * rho
    acc = acc + factor * newer

    for _ in range(2, K):
        factor = factor * rho
        # Chebyshev recurrence: T_{k+1} = 2 t T_k - T_{k-1}.
        fresh = 2 * stretch * apply_scaled(newer) - 2 * centre * newer - older
        acc = acc + factor * fresh
        older, newer = newer, fresh

    return acc / denom



def hessian_vector_product(grad_x, grad_y, x, params_y, v, lam, l_est):
    """Return ``hv_xx + hv_yx`` for the vector ``v``.

    Steps:
      1. ``hv_xx``: derivative of ``grad_x`` w.r.t. ``x.parameters()``
         contracted with ``v``.
      2. ``hv_xy``: derivative of ``grad_x`` w.r.t. ``params_y`` contracted
         with ``v``.
      3. ``v1 = cheb_poly(...)`` applied to ``hv_xy`` (presumably an
         approximate inverse-Hessian solve in ``y`` -- confirm).
      4. ``hv_yx``: derivative of ``grad_y`` w.r.t. ``x.parameters()``
         contracted with ``v1``, zero-padded to the length of ``grad_x``.

    Args:
        grad_x: Differentiable flat gradient w.r.t. the parameters of ``x``;
            its length fixes the length of the result.
        grad_y: Differentiable flat gradient w.r.t. ``params_y``.
        x: Object exposing ``parameters()`` (e.g. an ``nn.Module``).
        params_y: Parameters of the inner variable.
        v: Flat vector matching ``grad_x``, used as ``grad_outputs``.
        lam: Passed to ``cheb_poly`` as ``mu``.
        l_est: Passed to ``cheb_poly`` as ``l``.

    Returns:
        Flat tensor ``hv_xx + hv_yx``.
    """
    x_num = len(grad_x)
    x_params = list(x.parameters())

    hv_xx = nn.utils.parameters_to_vector(
        torch.autograd.grad(grad_x, x_params, grad_outputs=v, retain_graph=True))
    hv_xy = nn.utils.parameters_to_vector(
        torch.autograd.grad(grad_x, params_y, grad_outputs=v, retain_graph=True))

    # Number of series terms: ceil(sqrt(l_est / lam)), clamped to [5, 30].
    n_terms = min(max(math.ceil(math.sqrt(l_est / lam)), 5), 30)
    v1 = cheb_poly(n_terms, l_est, lam, hv_xy, grad_y, params_y)

    # allow_unused=True yields None for parameters that grad_y does not
    # depend on. Only the first two entries are flattened; the remainder is
    # represented by trailing zeros.
    # NOTE(review): this assumes exactly the first two parameters of x
    # influence grad_y -- confirm against the model definition.
    grads_yx = torch.autograd.grad(grad_y, x_params, grad_outputs=v1,
                                   retain_graph=True, allow_unused=True)
    hv_yx = nn.utils.parameters_to_vector(grads_yx[0:2])

    # Allocate the padding on hv_yx's own device/dtype instead of the
    # module-level device, so concatenation cannot fail on a device mismatch.
    padding = hv_yx.new_zeros(x_num - len(hv_yx))
    hv_yx = torch.cat((hv_yx, padding))

    return hv_xx + hv_yx


def hessian_vector_product3(grad_x, grad_y, x, params_y, s_fea, t_fea, s_tmp, t_tmp, v, lam, l_est):
    """Return ``hv_xx + hv_yx`` for ``v``, using the factored operator of ``cheb_poly2``.

    Steps:
      1. ``hv_xx``: derivative of ``grad_x`` w.r.t. ``x.parameters()``
         contracted with ``v``.
      2. ``hv_xy``: derivative of ``grad_x`` w.r.t. ``params_y`` contracted
         with ``v``.
      3. ``v1 = cheb_poly2(...)`` applied to ``hv_xy`` (presumably an
         approximate inverse-Hessian solve in ``y`` -- confirm).
      4. ``hv_yx``: derivative of ``grad_y`` w.r.t. ``x.parameters()``
         contracted with ``v1``, zero-padded to the length of ``grad_x``.

    Args:
        grad_x: Differentiable flat gradient w.r.t. the parameters of ``x``;
            its length fixes the length of the result.
        grad_y: Differentiable flat gradient w.r.t. ``params_y``.
        x: Object exposing ``parameters()`` (e.g. an ``nn.Module``).
        params_y: Parameters of the inner variable.
        s_fea, t_fea, s_tmp, t_tmp: Operator factors forwarded unchanged to
            ``cheb_poly2``.
        v: Flat vector matching ``grad_x``, used as ``grad_outputs``.
        lam: Passed to ``cheb_poly2`` as ``mu``.
        l_est: Passed to ``cheb_poly2`` as ``l``.

    Returns:
        Flat tensor ``hv_xx + hv_yx``.
    """
    x_num = len(grad_x)
    x_params = list(x.parameters())

    hv_xx = nn.utils.parameters_to_vector(
        torch.autograd.grad(grad_x, x_params, grad_outputs=v, retain_graph=True))
    hv_xy = nn.utils.parameters_to_vector(
        torch.autograd.grad(grad_x, params_y, grad_outputs=v, retain_graph=True))

    # Number of series terms: ceil(sqrt(l_est / lam)), clamped to [5, 30].
    n_terms = min(max(math.ceil(math.sqrt(l_est / lam)), 5), 30)
    v1 = cheb_poly2(n_terms, l_est, lam, hv_xy, s_fea, t_fea, s_tmp, t_tmp)

    # allow_unused=True yields None for parameters that grad_y does not
    # depend on. Only the first two entries are flattened; the remainder is
    # represented by trailing zeros.
    # NOTE(review): this assumes exactly the first two parameters of x
    # influence grad_y -- confirm against the model definition.
    grads_yx = torch.autograd.grad(grad_y, x_params, grad_outputs=v1,
                                   retain_graph=True, allow_unused=True)
    hv_yx = nn.utils.parameters_to_vector(grads_yx[0:2])

    # Allocate the padding on hv_yx's own device/dtype instead of the
    # module-level device, so concatenation cannot fail on a device mismatch.
    padding = hv_yx.new_zeros(x_num - len(hv_yx))
    hv_yx = torch.cat((hv_yx, padding))

    return hv_xx + hv_yx

def hessian_vector_product2(supp, grad_x, grad_y, x, params_y, hess_y, v, l, mu, chebepoch):
    """Return ``hv_xx + hv_yx`` for ``v``, using an explicit matrix ``hess_y``.

    Steps:
      1. ``hv_xx``: derivative of ``grad_x`` w.r.t. ``x.parameters()``
         contracted with ``v``.
      2. ``hv_xy``: derivative of ``grad_x`` w.r.t. ``params_y`` contracted
         with ``v``.
      3. ``v1 = cheb_poly3(chebepoch, l, mu, hv_xy, hess_y)`` (presumably an
         approximate solve with ``hess_y`` -- confirm).
      4. ``hv_yx``: derivative of ``grad_y`` w.r.t. ``x.parameters()``
         contracted with ``v1``; the first two entries are flattened and
         ``supp`` is appended.

    Args:
        supp: Tensor concatenated after the flattened ``hv_yx`` entries
            (presumably zero padding for parameters unused by ``grad_y`` --
            confirm against the caller).
        grad_x: Differentiable gradient w.r.t. the parameters of ``x``.
        grad_y: Differentiable gradient w.r.t. ``params_y``.
        x: Object exposing ``parameters()`` (e.g. an ``nn.Module``).
        params_y: Parameters of the inner variable.
        hess_y: Matrix forwarded to ``cheb_poly3`` as ``hess``.
        v: Flat vector used as ``grad_outputs``.
        l: Forwarded to ``cheb_poly3`` as ``l``.
        mu: Forwarded to ``cheb_poly3`` as ``mu``.
        chebepoch: Forwarded to ``cheb_poly3`` as the series length ``K``.

    Returns:
        Flat tensor ``hv_xx + hv_yx``.
    """
    x_params = list(x.parameters())

    hv_xx = nn.utils.parameters_to_vector(
        torch.autograd.grad(grad_x, x_params, grad_outputs=v, retain_graph=True))
    hv_xy = nn.utils.parameters_to_vector(
        torch.autograd.grad(grad_x, params_y, grad_outputs=v, retain_graph=True))

    v1 = cheb_poly3(chebepoch, l, mu, hv_xy, hess_y)

    # allow_unused=True yields None for parameters that grad_y does not
    # depend on; only the first two entries are flattened.
    # NOTE(review): this assumes exactly the first two parameters of x
    # influence grad_y -- confirm against the model definition.
    grads_yx = torch.autograd.grad(grad_y, x_params, grad_outputs=v1,
                                   retain_graph=True, allow_unused=True)
    hv_yx = torch.cat((nn.utils.parameters_to_vector(grads_yx[0:2]), supp))

    return hv_xx + hv_yx
